{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 最简单的 CART 回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'sp_ind': 0, 'sp_val': 18.1, 'left': {'sp_ind': 6, 'sp_val': 6.232, 'left': {'sp_ind': 0, 'sp_val': 31.0, 'left': 4.693194444444444, 'right': {'sp_ind': 7, 'sp_val': 47.4, 'left': 9.302297297297297, 'right': 6.2104}}, 'right': {'sp_ind': 7, 'sp_val': 69.6, 'left': 14.383142857142857, 'right': 10.810444444444444}}, 'right': {'sp_ind': 8, 'sp_val': 1.8347, 'left': 17.94710843373494, 'right': 23.82507462686567}}\n"
     ]
    }
   ],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "'''\n",
    "@FilePath: /nlp-exercise/boosting_tree/cart.py\n",
    "@Author: huzhu\n",
    "@Date: 2020-05-31 18:08:15\n",
    "@Description: Minimal CART regression tree in NumPy, fitted to the Boston housing data\n",
    "'''\n",
    "\n",
    "from sklearn.datasets import load_boston\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Boston housing data: `data` holds the features, `target` the house prices.\n",
    "# NOTE(review): load_boston was removed in scikit-learn 1.2, so this cell\n",
    "# presumably needs an older release -- confirm the pinned version.\n",
    "boston = load_boston()\n",
    "\n",
    "features, label = boston.data, boston.target\n",
    "# Append the target as the LAST column. reg_leaf / reg_err read column -1 and\n",
    "# choose_best_split only scans columns 0..n-2 as candidate features, so a\n",
    "# target placed at index 0 would be used as a split feature while the tree\n",
    "# predicted the last feature instead of the price.\n",
    "dataset = np.column_stack((features, label))\n",
    "\n",
    "def binsplit_dataset(dataset, feature, value):\n",
    "    mat0 = dataset[np.nonzero(dataset[:,feature] > value)[0],:]\n",
    "    mat1 = dataset[np.nonzero(dataset[:,feature] <= value)[0],:]\n",
    "    return mat0, mat1\n",
    "\n",
    "\n",
    "def reg_leaf(dataset):\n",
    "    return np.mean(dataset[:,-1])\n",
    "\n",
    "\n",
    "def reg_err(dataset):\n",
    "    return np.var(dataset[:,-1]) * np.shape(dataset)[0]\n",
    "\n",
    "\n",
    "def create_tree(dataset, leaf_type=reg_leaf, err_type=reg_err,ops=(1,4)):\n",
    "    \"\"\"Recursively build a regression tree as nested dicts.\n",
    "\n",
    "    Args:\n",
    "        dataset: 2-D array; the last column is read as the target.\n",
    "        leaf_type: callable that turns a dataset into a leaf value.\n",
    "        err_type: callable that returns the error of a dataset.\n",
    "        ops: ``(tolS, tolN)`` stopping parameters, forwarded unchanged to\n",
    "            choose_best_split.\n",
    "\n",
    "    Returns:\n",
    "        The leaf value when choose_best_split reports no split, otherwise a\n",
    "        dict with keys ``sp_ind`` (feature index), ``sp_val`` (threshold),\n",
    "        ``left`` and ``right`` (sub-trees). ``left`` is built from the rows\n",
    "        whose feature is greater than ``sp_val``; ``right`` from the rest.\n",
    "    \"\"\"\n",
    "    feat, val = choose_best_split(dataset, leaf_type, err_type, ops)\n",
    "    # Identity test for the None singleton (PEP 8); feature index 0 is a\n",
    "    # legitimate split and must not be confused with \"no split\".\n",
    "    if feat is None:\n",
    "        return val\n",
    "    lset, rset = binsplit_dataset(dataset, feat, val)\n",
    "    return {\n",
    "        \"sp_ind\": feat,\n",
    "        \"sp_val\": val,\n",
    "        \"left\": create_tree(lset, leaf_type, err_type, ops),\n",
    "        \"right\": create_tree(rset, leaf_type, err_type, ops),\n",
    "    }\n",
    "\n",
    "\n",
    "def choose_best_split(dataset, leaf_type=reg_leaf, err_type=reg_err, ops=(1,4)):\n",
    "    \"\"\"Search every feature/value pair for the split with the lowest error.\n",
    "\n",
    "    ``ops`` is ``(tolS, tolN)``: the minimum error reduction a split must\n",
    "    achieve, and the minimum number of rows each side must keep.\n",
    "\n",
    "    Returns ``(feature_index, split_value)`` for an accepted split, or\n",
    "    ``(None, leaf_type(dataset))`` when the node should become a leaf.\n",
    "    \"\"\"\n",
    "    min_gain = ops[0]\n",
    "    min_rows = ops[1]\n",
    "    # All targets identical: nothing left to split.\n",
    "    if len(set(dataset[:,-1].T.tolist())) == 1:\n",
    "        return None, leaf_type(dataset)\n",
    "    _, n_cols = np.shape(dataset)\n",
    "    base_err = err_type(dataset)\n",
    "    best_err, best_feat, best_val = np.inf, 0, 0\n",
    "    # The last column is the target, so only the first n_cols - 1 are scanned.\n",
    "    for col in range(n_cols - 1):\n",
    "        for candidate in set(dataset[:, col]):\n",
    "            upper, lower = binsplit_dataset(dataset, col, candidate)\n",
    "            # Only consider splits that leave enough rows on both sides.\n",
    "            if upper.shape[0] >= min_rows and lower.shape[0] >= min_rows:\n",
    "                split_err = err_type(upper) + err_type(lower)\n",
    "                if split_err < best_err:\n",
    "                    best_err, best_feat, best_val = split_err, col, candidate\n",
    "    # Give up when the best split does not reduce the error enough.\n",
    "    if (base_err - best_err) < min_gain:\n",
    "        return None, leaf_type(dataset)\n",
    "    upper, lower = binsplit_dataset(dataset, best_feat, best_val)\n",
    "    # Give up when the chosen split would leave a side that is too small.\n",
    "    if min(upper.shape[0], lower.shape[0]) < min_rows:\n",
    "        return None, leaf_type(dataset)\n",
    "    return best_feat, best_val\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    # Bound to cart_tree rather than `tree`: a later cell runs\n",
    "    # `from sklearn import tree`, and sharing the name makes either cell\n",
    "    # break the other when they are re-run out of order.\n",
    "    cart_tree = create_tree(dataset,leaf_type=reg_leaf, err_type=reg_err,ops=(1,50))\n",
    "    print(cart_tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sklearn.externals.six is gone from current scikit-learn; the standard\n",
    "# library StringIO is a drop-in replacement.\n",
    "from io import StringIO\n",
    "\n",
    "import pydotplus\n",
    "from sklearn import tree\n",
    "\n",
    "# x_train / y_train were never defined in this notebook, so the cell failed\n",
    "# on a fresh kernel. Fit on the Boston data loaded in the first cell.\n",
    "x_train, y_train = boston.data, boston.target\n",
    "clf = tree.DecisionTreeRegressor()\n",
    "clf = clf.fit(x_train,y_train)\n",
    "dot_data = StringIO()\n",
    "tree.export_graphviz(clf,out_file = dot_data,filled=True,rounded=True,\n",
    "                 special_characters=True)\n",
    "graph = pydotplus.graph_from_dot_data(dot_data.getvalue())\n",
    "# The markdown cell below embeds dtr.png, so render it alongside the PDF.\n",
    "graph.write_png(\"dtr.png\")\n",
    "graph.write_pdf(\"dtr.pdf\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](dtr.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
